#include "instance_state.h"

namespace triton { namespace backend { namespace ascend {

// Factory for ModelInstanceState. Converts a constructor-thrown
// BackendModelInstanceException into a TRITONSERVER_Error, and returns
// nullptr on success with *state pointing at the new instance.
TRITONSERVER_Error*
ModelInstanceState::Create(
    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance,
    ModelInstanceState** state)
{
    try {
        *state = new(std::nothrow) ModelInstanceState(model_state, triton_model_instance);
    }
    catch (const BackendModelInstanceException& ex) {
        RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
            std::string("unexpected nullptr in BackendModelInstanceException"));
        RETURN_IF_ERROR(ex.err_);
    }
    // new(std::nothrow) yields nullptr instead of throwing std::bad_alloc on
    // allocation failure; without this check a null instance would be
    // reported to the caller as success.
    RETURN_ERROR_IF_TRUE(*state == nullptr, TRITONSERVER_ERROR_INTERNAL,
        std::string("failed to allocate ModelInstanceState"));

    return nullptr;  // success
}

// Constructor: validates the instance group kind, loads the model onto the
// configured NPU device, caches model-level settings (batching, dynamic
// gear info, input/output formats) and validates the model I/O against the
// model configuration. Errors are reported by throwing
// BackendModelInstanceException via THROW_IF_BACKEND_INSTANCE_ERROR.
ModelInstanceState::ModelInstanceState(
    ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
    : BackendModelInstance(model_state, triton_model_instance),
    model_state_(model_state), device_(-1)
{

    // This backend only runs on NPU instance groups.
    if (Kind() != TRITONSERVER_INSTANCEGROUPKIND_NPU) {
        THROW_IF_BACKEND_INSTANCE_ERROR(
            TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            std::string("backend 'ascend' only support device 'NPU'").c_str())
        );
    }
 
    // Load the model file onto the device assigned to this instance.
    device_ = DeviceId();
    THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(ArtifactFilename(),
        device_, &model_path_, &ascend_model_));
    
    // TODO: handling for serialized-input models.
    supports_batching_ = model_state->MaxBatchSize() > 0;
    
    // Dynamic gear information (dynamic batch / HW / dims) from the config.
    dynamic_info_ = model_state->GetDynamicInfo();

    input_format_ = model_state_->GetInputFormat();

    output_format_ = model_state_->GetOutputFormat();

    // Check that model inputs/outputs match the model configuration.
    THROW_IF_BACKEND_INSTANCE_ERROR(VaildateInputs());
    THROW_IF_BACKEND_INSTANCE_ERROR(VaildateOutputs());
}

// Destructor: binds the calling thread to this instance's NPU before
// releasing the model so device resources are freed on the correct device.
ModelInstanceState::~ModelInstanceState()
{
    MxBase::DeviceContext context = {};
    context.devId = static_cast<int>(device_);
    MxBase::DeviceManager::GetInstance()->SetDevice(context);

    ascend_model_.reset();
    // Use the Triton logger instead of raw std::cout so the message honors
    // the server's log configuration.
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, "model reset");
}

// Entry point for one inference execution: validates the incoming requests,
// merges them into a single batched input, runs the model, sends one
// response per request, then reports per-request and per-batch statistics.
void
ModelInstanceState::ProcessRequests(
    TRITONBACKEND_Request** requests, const uint32_t request_count)
{
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
        (std::string("TRITONBACKEND_ModelExecute: Running ") + Name() +" in NPU '" +
        std::to_string(device_) + "' with " + std::to_string(request_count) + " requests").c_str());

    // Execution start timestamp.
    uint64_t exec_start_ns = 0;
    SET_TIMESTAMP(exec_start_ns);
    // Maximum batch size from the model configuration.
    const int max_batch_size = model_state_->MaxBatchSize();
    // For each request collect the total batch size for this inference
    // execution. The batch-size, number of inputs, and size of each
    // input has already been checked so don't need to do that here.
    size_t total_batch_size = 0;
    // Reject the whole execution if any request is null.
    for (size_t i = 0; i < request_count; i++) {
        // If we get a nullptr request then something is badly wrong. Fail
        // and release all requests.
        if (requests[i] == nullptr) {
            RequestsRespondWithError(requests, request_count,
                TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                    std::string("null request given to Ascend backend for '" + Name() + "'").c_str()));
            return;
        }
    }

    // Create one response object per request. A failed creation leaves a
    // nullptr placeholder so the indices stay aligned with 'requests'.
    std::vector<TRITONBACKEND_Response*> responses;
    responses.reserve(request_count);
    bool all_response_failed = false;

    for (size_t i = 0; i < request_count; i++) {
        TRITONBACKEND_Response* response;
        auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
        if (err == nullptr) {
            responses.emplace_back(response);
        } else {
            responses.emplace_back(nullptr);
            LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "failed to create response.");
            TRITONSERVER_ErrorDelete(err);
        }
    }
    // Accumulate the total batch size over all requests.
    for (size_t i = 0; i < request_count; i++) {
        if (supports_batching_) {
            // Retrieve the batch size from one of the inputs, if the model
            // supports batching, the first dimension size is batch size.
            TRITONBACKEND_Input* input;
            TRITONSERVER_Error* err = TRITONBACKEND_RequestInputByIndex(requests[i], 0 /* index */, &input);
            if (err == nullptr) {
                const int64_t* shape = nullptr;
                err = TRITONBACKEND_InputProperties(input, nullptr, nullptr, &shape, nullptr, nullptr, nullptr);
                // Only read the shape when InputProperties succeeded; the
                // original code dereferenced an uninitialized pointer on
                // the error path.
                if (err == nullptr) {
                    total_batch_size += shape[0];
                }
            }
            if (err != nullptr) {
                RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, all_response_failed, err);
            }
        } else {
            total_batch_size += 1;
        }
    }

    // If there are no valid payloads then no need to run the inference.
    // NOTE(review): this early return leaves the created responses unsent
    // and the requests unreleased -- the same pattern exists in other
    // Triton backends, but confirm it matches the scheduler's expectations.
    if (total_batch_size == 0) {
        return;
    }

    // Make sure the maximum batch size is not exceeded. The
    // total_batch_size must be 1 for models that don't support batching
    // (i.e. max_batch_size == 0). If max_batch_size is exceeded then
    // scheduler has done something badly wrong so fail and release all
    // requests.
    if (!all_response_failed) {
        if ((total_batch_size != 1) && (total_batch_size > (size_t)max_batch_size)) {
            RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, all_response_failed,
                TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                    std::string("batch size " + std::to_string(total_batch_size) + " for '" +
                        Name() + "', max allowed is " + std::to_string(max_batch_size)).c_str()));
        }
    }
    // Merge the inputs of all requests into contiguous batched buffers.
    std::vector<MxBase::Tensor> input_tensors;
    std::vector<const char*> input_memories;
    std::unique_ptr<BackendInputCollector> collector;
    if (!all_response_failed) {
        collector.reset(new(std::nothrow) BackendInputCollector(requests, request_count, &responses,
            model_state_->TritonMemoryManager(), model_state_->EnablePinnedInput(),
            nullptr /* stream */)
        );
        
        RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, all_response_failed,
            SetInputTensors(total_batch_size, requests, request_count, collector.get(),
            input_tensors, input_memories)
        );
    }
    
    std::vector<MxBase::Tensor> output_tensors;
    uint64_t compute_start_ns = 0;
    SET_TIMESTAMP(compute_start_ns);
    // Run...
    if (!all_response_failed) {
        Execute(&responses, request_count, input_tensors, output_tensors);
    }
    
    uint64_t compute_end_ns = 0;
    SET_TIMESTAMP(compute_end_ns);

    // Scatter the batched outputs back into the individual responses.
    if (!all_response_failed) {
        RESPOND_ALL_AND_SET_TRUE_IF_ERROR(responses, request_count, all_response_failed,
            ReadOutputTensors(
                total_batch_size, output_tensors, requests, request_count, &responses)
            );
    }
    uint64_t exec_end_ns = 0;
    SET_TIMESTAMP(exec_end_ns);
    // Send all the responses that haven't already been sent because of
    // an earlier error. Note that the responses are not set to nullptr
    // here as we need that indication below to determine if the request
    // we successful or not.
    for (auto& response : responses) {
        if (response != nullptr) {
            LOG_IF_ERROR(TRITONBACKEND_ResponseSend(
                    response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
                "failed to send Ascend backend response");
        }
    }

    // Report statistics for each request.
    for (uint32_t r = 0; r < request_count; ++r) {
        auto& request = requests[r];
        LOG_IF_ERROR(
            TRITONBACKEND_ModelInstanceReportStatistics(
                TritonModelInstance(), request, (responses[r] != nullptr),
                exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns),
            "failed reporting request statistics");
        LOG_IF_ERROR(
            TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL),
            "failed releasing request");
    }
    if (!all_response_failed) {
        LOG_IF_ERROR(
            TRITONBACKEND_ModelInstanceReportBatchStatistics(
                TritonModelInstance(), total_batch_size,
                exec_start_ns, compute_start_ns, compute_end_ns, exec_end_ns),
            "failed reporting batch request statistics");
    }

}

// Validate the model's inputs against the model configuration: same input
// count and, per input, a matching datatype. Also records the configured
// input names into model_input_names_ (in config order).
TRITONSERVER_Error*
ModelInstanceState::VaildateInputs()
{
    
    // Input count check.
    size_t expected_input_cnt = 0;

    triton::common::TritonJson::Value ios;
    if (model_state_->ModelConfig().Find("input", &ios)) {
        expected_input_cnt = ios.ArraySize();
    }

    uint32_t model_input_cnt = ascend_model_->GetInputTensorNum();

    // Dynamic (gear-based) models expose one extra implicit input that is
    // not listed in the configuration, so exclude it from the comparison.
    bool is_dynamic_model = model_state_->GetDynamicInfo().dynamic_type != DynamicType::STATIC_BATCH;

    model_input_cnt = is_dynamic_model ? model_input_cnt - 1 : model_input_cnt;
    if (model_input_cnt != expected_input_cnt) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("unable to load model '") + model_state_->Name() +
            "', configuration expects " + std::to_string(expected_input_cnt) +
            " inputs, model provides " + std::to_string(model_input_cnt)).c_str()
        );
    }

    // Datatype check for each configured input.
    for (size_t i = 0; i < expected_input_cnt; i++) {
        auto data_type = ascend_model_->GetInputTensorDataType(i);
        triton::common::TritonJson::Value io;
        RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
        // Record the configured input name.
        std::string io_name;
        RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
        model_input_names_.emplace_back(io_name);
        // Compare the configured datatype against the model's datatype.
        // Single find() instead of find()+operator[]: the double lookup was
        // wasteful and operator[] would insert an entry for unknown keys.
        std::string io_dtype;
        RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
        auto dtype_it = ModelConfigDTypeToMxBase.find(io_dtype);
        if (dtype_it == ModelConfigDTypeToMxBase.end() ||
            dtype_it->second != data_type) {
            return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                ("model configuration input '" + io_name + "' datatype: '" + io_dtype +
                "' but model '" + model_state_->Name() +
                "' require " + MxBaseDTypeToModelConfig[data_type]).c_str()
            );
        }
    }
    return nullptr;
}

// Validate the model's outputs against the model configuration: at least
// one output, same output count and, per output, a matching datatype. Also
// records the configured output names into model_output_names_.
TRITONSERVER_Error*
ModelInstanceState::VaildateOutputs()
{
    // Output count check.
    triton::common::TritonJson::Value ios;
    RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
    size_t expect_output_cnt = ios.ArraySize();
    if (expect_output_cnt == 0) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            "model configuration must contain at least one output, none were specified."
        );
    }

    uint32_t model_output_cnt = ascend_model_->GetOutputTensorNum();
    if (model_output_cnt != expect_output_cnt) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("unable to load model '") + model_state_->Name() +
            "', configuration expects " + std::to_string(expect_output_cnt) +
            " outputs, model provides " + std::to_string(model_output_cnt)).c_str()
        );
    }
    // Datatype check for each configured output.
    for (size_t i = 0; i < expect_output_cnt; i++) {
        auto data_type = ascend_model_->GetOutputTensorDataType(i);
        triton::common::TritonJson::Value io;
        RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
        // Record the configured output name.
        std::string io_name;
        RETURN_IF_ERROR(io.MemberAsString("name", &io_name));
        model_output_names_.emplace_back(io_name);

        // Compare the configured datatype against the model's datatype.
        // Single find() instead of find()+operator[]: the double lookup was
        // wasteful and operator[] would insert an entry for unknown keys.
        std::string io_dtype;
        RETURN_IF_ERROR(io.MemberAsString("data_type", &io_dtype));
        auto dtype_it = ModelConfigDTypeToMxBase.find(io_dtype);
        if (dtype_it == ModelConfigDTypeToMxBase.end() ||
            dtype_it->second != data_type) {
            return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                ("model configuration output '" + io_name + "' datatype: '" + io_dtype +
                "' but model '" + model_state_->Name() +
                "' require " + MxBaseDTypeToModelConfig[data_type]).c_str()
            );
        }
    }
    return nullptr;
}

// Run one inference: pad the inputs up to a shape the (gear-based dynamic)
// model accepts, infer, copy outputs to host, then crop the outputs back
// to the shape implied by the real (un-padded) inputs.
// NOTE(review): on error this function reports via SendErrorForResponses
// but then continues executing the remaining steps -- presumably benign
// once the responses are failed, but confirm.
void
ModelInstanceState::Execute(
    std::vector<TRITONBACKEND_Response*>* responses, const uint32_t response_count,
    std::vector<MxBase::Tensor>& input_tensors, std::vector<MxBase::Tensor>& output_tensors)
{
    // Compute the padded shapes: start from the actual input shapes and
    // let GetPaddingShape raise dims to the nearest supported gear.
    std::vector<std::vector<uint32_t>> input_shapes;
    for (size_t i = 0; i < input_tensors.size(); i++) {
        input_shapes.emplace_back(input_tensors[i].GetShape());
    }
    std::vector<std::vector<uint32_t>> padding_shapes = input_shapes;
    TRITONSERVER_Error* server_ret = GetPaddingShape(padding_shapes);
    if (server_ret != nullptr) {
        SendErrorForResponses(responses, response_count, server_ret);
    }

    // Only materialize padded copies when some shape actually changed.
    std::vector<MxBase::Tensor> input_tensors_pad;
    if (!IsSame(input_shapes, padding_shapes)) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Input tensors padding start.").c_str());
        for (size_t i = 0; i < input_tensors.size(); i++) {
            MxBase::TensorDType data_type = input_tensors[i].GetDataType();
            // Create the padded tensor on the HOST side.
            MxBase::Tensor pad_tensor(padding_shapes[i], data_type, -1);
            APP_ERROR ret = MxBase::Tensor::TensorMalloc(pad_tensor);
            if (ret != APP_ERR_OK) {
                SendErrorForResponses(responses, response_count,
                    TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                    std::string("failed to malloc buffer for padding tensor.").c_str())
                );
            }
            // Zero-fill so the padding region is all zeros.
            memset(pad_tensor.GetData(), 0, pad_tensor.GetByteSize());
            server_ret = TensorPadding(input_tensors[i], pad_tensor);
            if (server_ret != nullptr) {
                SendErrorForResponses(responses, response_count, server_ret);
            }
            input_tensors_pad.emplace_back(pad_tensor);
        }
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Input tensors padding end.").c_str());
    } else {
        input_tensors_pad = input_tensors;
    }
    
    // Run the model; MxBase reports failures as exceptions here.
    try {
        output_tensors = ascend_model_->Infer(input_tensors_pad);
    }
    catch (std::exception& ex) {
        SendErrorForResponses(responses, response_count,
            TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (
                "ascend execute failure: " + std::string(ex.what())).c_str())    
        );
    }
    if (output_tensors.empty()) {
        SendErrorForResponses(responses, response_count,
            TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            std::string("ascend execute failure: output is empty.").c_str())
        );
    }

    // Bring all outputs back to host memory before post-processing.
    for (size_t i = 0; i < output_tensors.size(); i++) {
        APP_ERROR ret = output_tensors[i].ToHost();
        if (ret != APP_ERR_OK) {
            SendErrorForResponses(responses, response_count,
                TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                std::string("failed to transform output tensors data to host.").c_str())
            );
        }

    }

    // Compute the crop shapes: the outputs come back padded, so shrink
    // them to the sizes implied by the original inputs.
    std::vector<std::vector<uint32_t>> output_shapes;
    for (size_t i = 0; i < output_tensors.size(); i++) {
        output_shapes.emplace_back(output_tensors[i].GetShape());
    }
    std::vector<std::vector<uint32_t>> crop_shapes = output_shapes;
        // Batch dimension was padded, so it can be cropped back.
    if (dynamic_info_.dynamic_type == DynamicType::DYNAMIC_BATCH || input_format_ != DataFormat::ND) {
        uint32_t batchsize = input_shapes[0][0];
        for (size_t i = 0; i < crop_shapes.size(); i++) {
            crop_shapes[i][0] = batchsize;
        }
    }
    // For dynamic-H/W models, derive each output's H/W from the input H/W
    // via linear interpolation between the model's min/max gears.
    // NOTE(review): assumes output i's H/W at gear indices [2*i, 2*i+1] and
    // that input gear 0/1 are H/W -- confirm against GetModelStrideInfo().
    if(dynamic_info_.dynamic_type == DynamicType::DYNAMIC_HW) {
        ModelStrideInfo model_stride_info = model_state_->GetModelStrideInfo();
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Start to get crop shape").c_str());
        if (!model_stride_info.min_output_gear.empty() && !model_stride_info.max_output_gear.empty()) {
            for (size_t i = 0; i < crop_shapes.size(); i++) {
                std::vector<uint32_t> index(2, 0);
                if (output_format_[i] == DataFormat::NCHW) {
                    index[0] = NCHW_HEIGHT_INDEX;
                    index[1] = NCHW_WIDTH_INDEX;
                } else if (output_format_[i] == DataFormat::NHWC) {
                    index[0] = NHWC_HEIGHT_INDEX;
                    index[1] = NHWC_WIDTH_INDEX;
                }   
                // Output-size change per unit of input-size change (H and W).
                float stride_h = (model_stride_info.max_output_gear[2*i] - model_stride_info.min_output_gear[2*i]) *
                    1.0f / (model_stride_info.max_input_gear[0] - model_stride_info.min_input_gear[0]);

                float stride_w = (model_stride_info.max_output_gear[2*i+1] - model_stride_info.min_output_gear[2*i+1]) *
                    1.0f / (model_stride_info.max_input_gear[1] - model_stride_info.min_input_gear[1]);

                if (input_format_ == DataFormat::NCHW) {
                    crop_shapes[i][index[0]] = (uint32_t)(((float)input_shapes[i][NCHW_HEIGHT_INDEX] -
                        (float)model_stride_info.min_input_gear[0]) * stride_h + (float)model_stride_info.min_output_gear[2*i]);
                    crop_shapes[i][index[1]] = (uint32_t)(((float)input_shapes[i][NCHW_WIDTH_INDEX] -
                        (float)model_stride_info.min_input_gear[1]) * stride_w + (float)model_stride_info.min_output_gear[2*i+1]);
                } else if (input_format_ == DataFormat::NHWC) {
                    crop_shapes[i][index[0]] = (uint32_t)(((float)input_shapes[i][NHWC_HEIGHT_INDEX] -
                        (float)model_stride_info.min_input_gear[0]) * stride_h + (float)model_stride_info.min_output_gear[2*i]);
                    crop_shapes[i][index[1]] = (uint32_t)(((float)input_shapes[i][NHWC_WIDTH_INDEX] -
                        (float)model_stride_info.min_input_gear[1]) * stride_w + (float)model_stride_info.min_output_gear[2*i+1]);
                }
                
                LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("Set model '") + Name() + "' num i: " + std::to_string(i) +
                    " index j:" + std::to_string(index[0]) + " crop shape '" +
                    std::to_string(crop_shapes[i][index[0]]) + "'").c_str());
                
                LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, (std::string("Set model '") + Name() + "' num i: " + std::to_string(i) +
                    " index j:" + std::to_string(index[1]) + " crop shape '" +
                    std::to_string(crop_shapes[i][index[1]]) + "'").c_str());
            }
            
        }
    }
    // Materialize cropped copies only if some output shape must shrink;
    // otherwise leave output_tensors untouched.
    std::vector<MxBase::Tensor> output_tensors_pad;
    if (!IsSame(output_shapes, crop_shapes)) {
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Output tensors crop start.").c_str());
        for (size_t i = 0; i < output_tensors.size(); i++) {
            MxBase::TensorDType data_type = output_tensors[i].GetDataType();
            // Create the cropped tensor on the HOST side.
            MxBase::Tensor crop_tensor(crop_shapes[i], data_type, -1);
            APP_ERROR ret = MxBase::Tensor::TensorMalloc(crop_tensor);
            if (ret != APP_ERR_OK) {
                SendErrorForResponses(responses, response_count,
                    TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                    std::string("failed to malloc buffer for crop tensor.").c_str())
                );
            }
            // Zero-fill the destination buffer before the copy.
            memset(crop_tensor.GetData(), 0, crop_tensor.GetByteSize());
            server_ret = TensorCrop(output_tensors[i], crop_tensor);
            if (server_ret != nullptr) {
                SendErrorForResponses(responses, response_count, server_ret);
            }
            output_tensors_pad.emplace_back(crop_tensor);
        }
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Output tensors crop end.").c_str());
        output_tensors.swap(output_tensors_pad);
    }
}

// Strategy: gather all (CPU-resident) request inputs into contiguous
// batched buffers via the collector, then wrap those buffers as MxBase
// tensors for inference.
//
// On success 'input_tensors' holds one tensor per model input and
// 'input_buffers' records the backing buffer pointers (owned by the
// collector/Triton, not by the tensors).
TRITONSERVER_Error*
ModelInstanceState::SetInputTensors(
    size_t total_batch_size, TRITONBACKEND_Request** requests, const uint32_t request_count,
    BackendInputCollector* collector, std::vector<MxBase::Tensor>& input_tensors,
    std::vector<const char*>& input_buffers)
{

    // All requests must expose the same number of inputs, so the count and
    // per-input properties are read from the first request.
    uint32_t input_count = 0;
    RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));
    input_buffers.resize(input_count);
    for (uint32_t i = 0; i < input_count; i++) {
        TRITONBACKEND_Input* input;
        RETURN_IF_ERROR(TRITONBACKEND_RequestInputByIndex(requests[0], i, &input));
        const char* input_name;
        TRITONSERVER_DataType input_datatype;
        const int64_t* input_shape;
        uint32_t input_dims_count;
        RETURN_IF_ERROR(TRITONBACKEND_InputProperties(input,
            &input_name, &input_datatype, &input_shape, &input_dims_count, nullptr, nullptr)
        );

        // Build the batched shape from this input's dims.
        std::vector<int64_t> batchn_shape(input_shape, input_shape + input_dims_count);
        // When batching is supported the first dimension becomes the total
        // batch size accumulated over all requests.
        if (supports_batching_) {
            batchn_shape[0] = total_batch_size;
        }
        // Convert to the uint32_t shape MxBase expects, rejecting any dim
        // that does not fit. (Loop index renamed to 'j': the original
        // shadowed the outer input index 'i'.)
        std::vector<uint32_t> ascend_batchn_shape;
        for (size_t j = 0; j < batchn_shape.size(); j++) {
            if (batchn_shape[j] <= 0 || batchn_shape[j] > UINT32_MAX) {
                return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
                    (std::string("model '") + model_state_->Name() +
                    "', inputs shape '" + std::to_string(batchn_shape[j]) +
                    "' is invalid.").c_str()
                );
            }
            ascend_batchn_shape.emplace_back((uint32_t)batchn_shape[j]);
        }
        // Inputs are staged on CPU; pinned memory is preferred, which can
        // speed up later host-to-device transfers.
        std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_preference = {
            {TRITONSERVER_MEMORY_CPU_PINNED, 0},
            {TRITONSERVER_MEMORY_CPU, 0}
        };

        const char* input_buffer = nullptr;
        size_t batchn_byte_size = 0;
        TRITONSERVER_MemoryType memory_type;
        int64_t memory_type_id = 0;
        // Let the collector allocate one contiguous buffer holding this
        // input's data gathered from all requests.
        RETURN_IF_ERROR(collector->ProcessTensor(input_name, nullptr, 0, alloc_preference, &input_buffer,
            &batchn_byte_size, &memory_type, &memory_type_id)
        );
        // Record the buffer so the caller can track its lifetime.
        input_buffers[i] = input_buffer;
        // Datatypes were validated against the config at load time, so the
        // lookup is assumed to succeed here.
        MxBase::TensorDType ascend_type = TrironDTypeToMxBase[input_datatype];

        // The tensor wraps user-owned memory and will not free it.
        MxBase::Tensor input_tensor(
            const_cast<void*>(static_cast<const void*>(input_buffer)),
            ascend_batchn_shape, ascend_type);

        input_tensors.emplace_back(std::move(input_tensor));

    }
    // Flush any deferred copies and release collector resources.
    collector->Finalize();
    
    return nullptr;
}

// Scatter the batched model outputs back into per-request responses.
// Output datatypes were validated against the model configuration at load
// time, so no per-tensor datatype re-check is done here.
TRITONSERVER_Error*
ModelInstanceState::ReadOutputTensors(
    size_t total_batch_size,
    const std::vector<MxBase::Tensor>& output_tensors,
    TRITONBACKEND_Request** requests, const uint32_t request_count,
    std::vector<TRITONBACKEND_Response*>* responses)
{
    // Validate every output shape up front: the original returned from the
    // middle of the processing loop on an empty shape, skipping
    // responder.Finalize() after some tensors were already processed.
    for (size_t i = 0; i < output_tensors.size(); i++) {
        if (output_tensors[i].GetShape().empty()) {
            return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
                (std::string("output '") + std::to_string(i) + "' is invalid which is not supported.").c_str()
            );
        }
    }

    BackendOutputResponder responder(
        requests, request_count, responses, model_state_->TritonMemoryManager(),
        model_state_->MaxBatchSize() > 0, false, nullptr);
    for (size_t i = 0; i < output_tensors.size(); i++) {
        // Convert the MxBase uint32_t shape to Triton's int64_t shape.
        std::vector<int64_t> batchn_shape;
        auto shape = output_tensors[i].GetShape();
        for (auto item : shape) {
            batchn_shape.push_back((int64_t)item);
        }
        auto dataType = output_tensors[i].GetDataType();
        TRITONSERVER_DataType output_dtype = MxBaseDTypeToTriton[dataType];
        const char* output_buffer = (const char*)output_tensors[i].GetData();
        // Slice this output across the requests' responses (CPU memory).
        responder.ProcessTensor(
            model_output_names_[i], output_dtype, batchn_shape, output_buffer,
            TRITONSERVER_MEMORY_CPU, 0);
    }
    responder.Finalize();
    return nullptr;
}

// Compute, in place, the shapes the inputs must be padded to so the model
// accepts them: first apply the gear rules for the model's dynamic type
// (batch / HW / dims), then force any remaining dim to the model's fixed
// size queried from the model itself.
TRITONSERVER_Error*
ModelInstanceState::GetPaddingShape(std::vector<std::vector<uint32_t>>& padding_shapes)
{
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE, std::string("Start to get padding shape").c_str());
    switch (dynamic_info_.dynamic_type)
    {
    case DynamicType::DYNAMIC_BATCH:
        for (size_t i = 0; i< padding_shapes.size(); i++) {
            RETURN_IF_ERROR(GetDynamicBatchPaddingShape(padding_shapes[i]));
        }
        break;
    case DynamicType::DYNAMIC_HW:
        for (size_t i = 0; i< padding_shapes.size(); i++) {
            RETURN_IF_ERROR(GetDynamicHWPaddingShape(padding_shapes[i]));
        }
        break;
    case DynamicType::DYNAMIC_DIMS:
        RETURN_IF_ERROR(GetDynamicDimsPaddingShape(padding_shapes));
        break;
    case DynamicType::DYNAMIC_SHAPE:
        // Fully dynamic-shape models take the input as-is; no padding.
        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
            std::string("Padding of dynamic shape model is not supported.").c_str());
        break;
    default:
        break;
    }
    // For dims without a gear, query the model for the fixed size to pad to.
    // NOTE(review): the '!= -1' test assumes GetInputTensorShape() elements
    // are signed (or that -1 converts to the intended sentinel) -- confirm
    // the element type.
    for (size_t i = 0; i < padding_shapes.size(); i++) {
        auto model_shape = ascend_model_->GetInputTensorShape(i);
        for (size_t j = 0; j < padding_shapes[i].size(); j++) {
            if (padding_shapes[i][j] != model_shape[j] && model_shape[j] != -1) {
                padding_shapes[i][j] = model_shape[j];
                LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
                    (std::string("Set model '") + Name() + "' num i: " + std::to_string(i) +
                    " index j:" + std::to_string(j) + " padding shape '" +
                    std::to_string(padding_shapes[i][j]) + "'").c_str());
            }
        }
    }
    return nullptr;
}

// Pad the batch dimension (dim 0) up to the smallest supported batch gear
// that can hold the requested batch. Fails when the request exceeds the
// largest gear of the model.
TRITONSERVER_Error*
ModelInstanceState::GetDynamicBatchPaddingShape(std::vector<uint32_t>& input_shape)
{
    // Work on a sorted copy of the configured batch gears.
    auto gears = dynamic_info_.dynamic_batch;
    std::sort(gears.begin(), gears.end());

    // First gear that is >= the requested batch.
    const auto gear_it = std::lower_bound(gears.begin(), gears.end(), input_shape[0]);
    if (gear_it == gears.end()) {
        // Every gear is smaller than the requested batch.
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                std::string("model input batch is '" + std::to_string(input_shape[0]) +
                "' but the maximum supported batch of the model is '" +
                std::to_string(gears.back()) + "'").c_str());
    }

    input_shape[0] = (uint32_t)(*gear_it);
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
        (std::string("Set model '") + Name() + "' batch padding shape: " +
        std::to_string(input_shape[0]) + "'").c_str());
    return nullptr;
}

// Pad the H/W dimensions of 'input_shape' to a supported H/W gear of a
// dynamic-HW model, choosing the smallest-area gear that still covers the
// input. The H/W positions depend on the input format (NCHW vs NHWC).
TRITONSERVER_Error*
ModelInstanceState::GetDynamicHWPaddingShape(std::vector<uint32_t>& input_shape)
{
    auto dynamic_size = dynamic_info_.dynamic_size;
    // index[0]/index[1] are the positions of H/W in the shape vector.
    std::vector<uint32_t> index(2, 0);
    if (input_format_ == DataFormat::NCHW) {
        index[0] = NCHW_HEIGHT_INDEX;
        index[1] = NCHW_WIDTH_INDEX;
    } else if (input_format_ == DataFormat::NHWC) {
        index[0] = NHWC_HEIGHT_INDEX;
        index[1] = NHWC_WIDTH_INDEX;
    } else {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            std::string("Dynamic hw model input format should be NCHW or NHWC.").c_str());
    }
    std::vector<uint32_t> input_hw(2, 0);
    input_hw[0] = input_shape[index[0]];
    input_hw[1] = input_shape[index[1]];
    // Lexicographic sort + lower_bound: first gear whose (H, W) pair is
    // >= the input pair in lexicographic order.
    std::sort(dynamic_size.begin(), dynamic_size.end());
    auto it = std::lower_bound(dynamic_size.begin(), dynamic_size.end(), input_hw);
    if (it == dynamic_size.end()) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            std::string("The maximum gear of the model is: H'" + std::to_string(dynamic_size.back()[0])
            + "' W'" + std::to_string(dynamic_size.back()[1]) + "'. Please check input shape").c_str());
    }
    // Among the remaining gears whose width also covers the input width,
    // pick the one with the smallest H*W area.
    // NOTE(review): min_it starts at the lower_bound gear even if that
    // gear's width is below input width; if no later gear satisfies the
    // width condition the chosen gear may not cover the input -- confirm
    // this is intended.
    size_t min_area = (*it)[0] * (*it)[1];
    auto min_it = it;
    for (; it != dynamic_size.end(); it++) {
        if ((*it)[1] >= input_hw[1]) {
            size_t temp_area = (*it)[0] * (*it)[1];
            if (temp_area < min_area) {
                min_area = temp_area;
                min_it = it;
            }
        }
    }
    input_shape[index[0]] = (uint32_t)(*min_it)[0];
    input_shape[index[1]] = (uint32_t)(*min_it)[1];
    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
        (std::string("Set model '") + Name() + "' size padding shape h: " +
        std::to_string(input_shape[index[0]]) + "' w: '" + std::to_string(input_shape[index[1]]) + "'").c_str());
    return nullptr;
}

// Pad the dynamic (-1) dims of all inputs to a supported "dims" gear.
// Pass 1 collects the current values of every dynamic dim (in model input
// order) into a flat vector; that vector is matched against the sorted
// gear list with lower_bound; pass 2 writes the chosen gear's values back
// into the same dynamic positions.
// NOTE(review): lower_bound uses lexicographic comparison over the flat
// vector -- presumably the gear found always dominates every dim of the
// input; confirm the gears are defined so this holds.
TRITONSERVER_Error*
ModelInstanceState::GetDynamicDimsPaddingShape(std::vector<std::vector<uint32_t>>& input_shapes)
{
    std::vector<uint32_t> check_shapes;
    for (size_t i = 0; i < input_shapes.size(); i++) {
        auto model_shape = ascend_model_->GetInputTensorShape(i);
        for (size_t j = 0; j < input_shapes[i].size(); j++) {
            if (model_shape[j] == -1) {
                check_shapes.emplace_back((uint32_t)input_shapes[i][j]);
            }
        }
    }
    auto dynamic_dims = dynamic_info_.dynamic_dims;
    std::sort(dynamic_dims.begin(), dynamic_dims.end());
    auto it = std::lower_bound(dynamic_dims.begin(), dynamic_dims.end(), check_shapes);
    if (it == dynamic_dims.end()) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
            std::string("Input shape is not meeting with the gear of model dims").c_str());
    }
    // Write the selected gear back into the dynamic positions, consuming
    // its values in the same order they were collected above.
    size_t index = 0;
    for (size_t i = 0; i < input_shapes.size(); i++) {
        auto model_shape = ascend_model_->GetInputTensorShape(i);
        for (size_t j = 0; j < input_shapes[i].size(); j++) {
            if (model_shape[j] == -1) {
                input_shapes[i][j] = (uint32_t)(*it)[index];
                index++;
            }
        }
    }
    return nullptr;
}

bool ModelInstanceState::IsSame(std::vector<std::vector<uint32_t>>& input_shapes,
    std::vector<std::vector<uint32_t>>& padding_shapes)
{
    for (size_t i = 0; i < input_shapes.size(); i++) {
        for (size_t j = 0; j < input_shapes[i].size(); j++) {
            if (input_shapes[i][j] != padding_shapes[i][j]) {
                return false; 
            }
        }
    }
    return true;
}

// Deep copy via memcpy; to be optimized later.
// Copies 'input_tensor' into the leading corner of the (already
// zero-filled) larger 'padding_tensor', row by innermost row. Both tensors
// must be host-resident, have the same rank, and every padding dim must be
// >= the corresponding input dim.
TRITONSERVER_Error*
ModelInstanceState::TensorPadding(MxBase::Tensor& input_tensor, MxBase::Tensor& padding_tensor)
{
    // e.g. input 1,3,3,3 -> padding 3,6,4,4
    std::vector<uint32_t> input_shape = input_tensor.GetShape();
    std::vector<uint32_t> padding_shape = padding_tensor.GetShape();

    size_t input_dims = input_shape.size();
    size_t padding_dims = padding_shape.size();

    // Rank must match.
    if (input_dims != padding_dims) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("The shape sizes of input and padding are not equal. input shape size is '") +
            std::to_string(input_dims) + "' padding shape size is '" +
            std::to_string(padding_dims) + "'").c_str());
    }
    // Every padding dim must be at least as large as the input dim.
    for (size_t i = 0; i < input_dims; i++) {
        if (input_shape[i] > padding_shape[i]) {
            return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("Input shape is bigger than padding shape. input shape is '") +
            std::to_string(input_shape[i]) + "' padding shape is '" +
            std::to_string(padding_shape[i]) + "'").c_str());
        }
    }

    // Rank-1 tensors are a single contiguous run: one memcpy suffices.
    void* input_ptr = input_tensor.GetData();
    void* padding_ptr = padding_tensor.GetData();
    if (input_dims == 1) {
        memcpy(padding_ptr, input_ptr, input_tensor.GetByteSize());
        return nullptr;
    }
    // Build cumulative element counts from the innermost dim outwards:
    // stride[k] = number of elements spanned by the last k+1 dims. Note the
    // vectors are therefore indexed in REVERSE dim order below.
    size_t input_size = 1;
    size_t padding_size = 1;
    std::vector<size_t> input_stride;
    std::vector<size_t> padding_stride;
    // input_stride->3, 9, 27, 27  padding_stride->4, 16, 96, 288
    for (int64_t i = input_dims - 1; i >= 0; i--) {
        input_size *= input_shape[i];
        input_stride.emplace_back(input_size);
        padding_size *= padding_shape[i];
        padding_stride.emplace_back(padding_size);
    }
    // Per-element byte width: (u)int32/float32 -> 4, (u)int8 -> 1,
    // float64 -> 8, float16 -> 2.
    uint32_t byte_size = input_tensor.GetByteSize() / input_size;
    // 'index' is an odometer over all dims except the innermost; each
    // iteration copies one innermost row, then the odometer is advanced
    // with carry from the last position toward index[0].
    std::vector<size_t> index(input_dims - 1, 0);
    while (index[0] < input_shape[0]) {
        // Compute the flat element offsets of the current row in both
        // tensors (stride vectors are reverse-ordered, hence dims-2-i).
        size_t input_offeset = 0;
        size_t padding_offeset = 0;
        for (size_t i = 0; i < index.size(); i++) {
            input_offeset += index[i] * input_stride[input_dims - 2 - i];
            padding_offeset += index[i] * padding_stride[padding_dims - 2 - i];
        }
        void* input_begin_ptr = (void*)((uint8_t*)input_ptr + input_offeset * byte_size);
        void* padding_begin_ptr = (void*)((uint8_t*)padding_ptr + padding_offeset * byte_size);
        memcpy(padding_begin_ptr, input_begin_ptr, input_shape[input_dims - 1] * byte_size);
        // Advance the odometer and propagate carries (index[0] is allowed
        // to reach input_shape[0], which terminates the loop).
        index[input_dims - 2]++;
        for (int64_t i = input_dims - 2; i > 0; i--) {
            if (index[i] == input_shape[i]) {
                index[i] = 0;
                index[i - 1] += 1;
            }
        }
        
    }
    return nullptr;
}

// Deep copy via memcpy: copies the leading crop_shape-sized region of
// input_tensor into crop_tensor, one innermost row at a time.
// Needs optimization later (avoid per-row memcpy on large tensors).
//
// Returns nullptr on success; a TRITONSERVER_ERROR_INVALID_ARG error if the
// ranks differ or any crop dimension exceeds the matching input dimension.
TRITONSERVER_Error*
ModelInstanceState::TensorCrop(MxBase::Tensor& input_tensor, MxBase::Tensor& crop_tensor)
{
    // Example: input shape 3, 6, 4, 4 cropped down to 1, 3, 3, 3
    std::vector<uint32_t> input_shape = input_tensor.GetShape();
    std::vector<uint32_t> crop_shape = crop_tensor.GetShape();

    size_t input_dims = input_shape.size();
    size_t crop_dims = crop_shape.size();

    // Both tensors must have the same rank.
    if (input_dims != crop_dims) {
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("The shape sizes of input and crop are not equal. input shape size is '") +
            std::to_string(input_dims) + "' crop shape size is '" +
            std::to_string(crop_dims) + "'").c_str());
    }
    // Each crop dimension must fit inside the corresponding input dimension.
    for (size_t i = 0; i < input_dims; i++) {
        if (input_shape[i] < crop_shape[i]) {
            // BUGFIX: the previous message had the relation reversed ("Input
            // shape is bigger than crop shape"); this branch fires when the
            // crop shape exceeds the input shape.
            return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
            (std::string("Crop shape is bigger than input shape. input shape is '") +
            std::to_string(input_shape[i]) + "' crop shape is '" +
            std::to_string(crop_shape[i]) + "'").c_str());
        }
    }

    // For a 1-D tensor the crop region is contiguous: a single copy suffices.
    void* input_ptr = input_tensor.GetData();
    void* crop_ptr = crop_tensor.GetData();
    if (input_dims == 1) {
        memcpy(crop_ptr, input_ptr, crop_tensor.GetByteSize());
        return nullptr;
    }
    size_t input_size = 1;
    size_t crop_size = 1;
    std::vector<size_t> input_stride;
    std::vector<size_t> crop_stride;
    // Strides accumulate innermost-first. For the example above:
    // input_stride -> 4, 16, 96, 288  crop_stride -> 3, 9, 27, 27
    for (int64_t i = input_dims - 1; i >= 0; i--) {
        input_size *= input_shape[i];
        input_stride.emplace_back(input_size);
        crop_size *= crop_shape[i];
        crop_stride.emplace_back(crop_size);
    }
    // Per-element byte width derived from the total byte size:
    // (u)int32 , float32 -> 4  (u)int8 -> 1  float64 -> 8  float16 -> 2
    uint32_t byte_size = crop_tensor.GetByteSize() / crop_size;
    // Multi-dimensional counter over every crop dimension except the last.
    std::vector<size_t> index(crop_dims - 1, 0);
    while (index[0] < crop_shape[0]) {
        size_t input_offset = 0;
        size_t crop_offset = 0;
        // Stride for dimension i sits at reversed position dims - 2 - i.
        for (size_t i = 0; i < index.size(); i++) {
            input_offset += index[i] * input_stride[input_dims - 2 - i];
            crop_offset += index[i] * crop_stride[crop_dims - 2 - i];
        }
        void* input_begin_ptr = (void*)((uint8_t*)input_ptr + input_offset * byte_size);
        void* crop_begin_ptr = (void*)((uint8_t*)crop_ptr + crop_offset * byte_size);
        // Copy only the innermost (last) crop dimension per iteration.
        memcpy(crop_begin_ptr, input_begin_ptr, crop_shape[crop_dims - 1] * byte_size);
        // Advance the counter, propagating carries toward index[0];
        // the loop terminates once index[0] overflows crop_shape[0].
        index[crop_dims - 2]++;
        for (int64_t i = crop_dims - 2; i > 0; i--) {
            if (index[i] == crop_shape[i]) {
                index[i] = 0;
                index[i - 1] += 1;
            }
        }
    }
    return nullptr;
}
}}};